{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
    "                              TensorDataset)\n",
    "import argparse\n",
    "import logging\n",
    "import os\n",
    "import json\n",
    "import time\n",
    "import torch.nn.functional as F\n",
    "from preprocess import LCSTSProcessor\n",
    "from model import BertAbsSum\n",
    "from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
    "from pytorch_pretrained_bert.modeling import BertModel\n",
    "from pytorch_pretrained_bert.optimization import BertAdam\n",
    "from preprocess import convert_examples_to_features\n",
    "from tqdm import tqdm, trange\n",
    "from transformer import Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',\n",
    "                    datefmt='%m/%d/%Y %H:%M:%S',\n",
    "                    level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_loss(dratf_logits, refine_logits, ground):\n",
    "    ground = ground[:, 1:]\n",
    "    draft_loss = F.cross_entropy(dratf_logits, ground, ignore_index=Constants.PAD)\n",
    "    refine_loss = F.cross_entropy(refine_logits, ground, ignore_index=Constants.PAD)\n",
    "    return draft_loss + refine_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'data/processed_data'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ARGS(object):\n",
    "    data_dir = 'data/processed_data'\n",
    "    bert_model = 'chinese_L-12_H-768_A-12'\n",
    "    output_dir = 'output'\n",
    "    GPU_index = 0\n",
    "    learning_rate = 5e-5\n",
    "    num_train_epochs = 3\n",
    "    warmup_proportion = 0.1\n",
    "    max_src_len = 130\n",
    "    max_tgt_len = 30\n",
    "    train_batch_size = 32\n",
    "    decoder_config = None\n",
    "    print_every = 100\n",
    "\n",
    "args = ARGS()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', args.GPU_index)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "logger.info(f'Using device:{device}')\n",
    "\n",
    "if not os.path.exists(args.output_dir):\n",
    "    os.makedirs(args.output_dir)\n",
    "model_path = os.path.join(args.output_dir, time.strftime('model_%m-%d-%H:%M:%S', time.localtime()))\n",
    "os.mkdir(model_path)\n",
    "logger.info(f'Saving model to {model_path}.')\n",
    "\n",
    "if args.decoder_config is not None:\n",
    "    with open(args.decoder_config, 'r') as f:\n",
    "        decoder_config = json.load(f)\n",
    "else:\n",
    "    with open(os.path.join(args.bert_model, 'bert_config.json'), 'r') as f:\n",
    "        bert_config = json.load(f)\n",
    "        decoder_config = {}\n",
    "        decoder_config['len_max_seq'] = args.max_tgt_len\n",
    "        decoder_config['d_word_vec'] = bert_config['vocab_size']\n",
    "        decoder_config['n_layers'] = 8\n",
    "        decoder_config['num_head'] = 12\n",
    "        decoder_config['d_k'] = 64\n",
    "        decoder_config['d_v'] = 64\n",
    "        decoder_config['d_model'] = bert_config['hidden_size']\n",
    "        decoder_config['d_inner'] = decoder_config['d_model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data preprocess\n",
    "processor = LCSTSProcessor()\n",
    "tokenizer = BertTokenizer.from_pretrained(args.bert_model)\n",
    "logger.info('Loading train examples...')\n",
    "train_examples = processor.get_train_examples('data/processed_data')\n",
    "num_train_optimization_steps = int(len(train_examples) / args.train_batch_size) * args.num_train_epochs\n",
    "logger.info('Converting train examples to features...')\n",
    "features = convert_examples_to_features(train_examples, args.max_src_len, args.max_tgt_len, tokenizer)\n",
    "example = train_examples[0]\n",
    "example_feature = features[0]\n",
    "logger.info(\"*** Example ***\")\n",
    "logger.info(\"guid: %s\" % (example.guid))\n",
    "logger.info(\"src text: %s\" % example.src)\n",
    "logger.info(\"src_ids: %s\" % \" \".join([str(x) for x in example_feature.src_ids]))\n",
    "logger.info(\"src_mask: %s\" % \" \".join([str(x) for x in example_feature.src_mask]))\n",
    "logger.info(\"tgt text: %s\" % example.tgt)\n",
    "logger.info(\"tgt_ids: %s\" % \" \".join([str(x) for x in example_feature.tgt_ids]))\n",
    "logger.info(\"tgt_mask: %s\" % \" \".join([str(x) for x in example_feature.tgt_mask]))\n",
    "logger.info('Building dataloader...')\n",
    "all_src_ids = torch.tensor([f.src_ids for f in features], dtype=torch.long)\n",
    "all_src_mask = torch.tensor([f.src_mask for f in features], dtype=torch.long)\n",
    "all_tgt_ids = torch.tensor([f.tgt_ids for f in features], dtype=torch.long)\n",
    "all_tgt_mask = torch.tensor([f.tgt_mask for f in features], dtype=torch.long)\n",
    "train_data = TensorDataset(all_src_ids, all_src_mask, all_tgt_ids, all_tgt_mask)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)\n",
    "\n",
    "# model\n",
    "model = BertAbsSum(args.bert_model, decoder_config, )\n",
    "model.to(device)\n",
    "\n",
    "# optimizer\n",
    "param_optimizer = list(model.named_parameters())\n",
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
    "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]\n",
    "optimizer = BertAdam(optimizer_grouped_parameters,\n",
    "                     lr=args.learning_rate,\n",
    "                     warmup=0.1,\n",
    "                     t_total=num_train_optimization_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"***** Running training *****\")\n",
    "logger.info(\"  Num examples = %d\", len(train_examples))\n",
    "logger.info(\"  Batch size = %d\", args.train_batch_size)\n",
    "logger.info(\"  Num steps = %d\", num_train_optimization_steps)\n",
    "model.train()\n",
    "global_step = 0\n",
    "for i in trange(int(args.num_train_epochs), desc=\"Epoch\"):\n",
    "    tr_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\n",
    "        batch = tuple(t.to(device) for t in batch)\n",
    "        draft_logits, refine_logits = model(*batch)\n",
    "        loss = cal_loss(draft_logits, refine_logits, batch[2])\n",
    "        loss.backward()\n",
    "        tr_loss += loss.item()\n",
    "        nb_tr_examples += batch[0].size(0)\n",
    "        nb_tr_steps += 1\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        global_step += 1\n",
    "    if step % args.print_every == 0:\n",
    "        logger.info(f'Epoch {i}, step {step}, loss {loss.item()}.')\n",
    "    torch.save(model.state_dict(), os.join(model_path, 'BertAbsSum.bin'))\n",
    "    logger.info(f'Epoch {i} finished. Model saved.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:sum]",
   "language": "python",
   "name": "conda-env-sum-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
